import chainer
from chainer import Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer import cuda


import ot
import os
from sklearn.metrics import pairwise_distances

# Index of the CUDA device this script runs on; the experiment functions
# below also consult it before moving the models to the GPU.
gpu_device = 0
# Module-level side effect: make that device the current one at import time.
cuda.get_device(gpu_device).use()

import numpy as np

from matplotlib import pyplot as plt

from datasets.synthetic import prepare_25gaussian_data, prepare_swissroll_data
from gen_models.toygen import Generator
from dis_models.toydis import Discriminator

import yaml
import source.yaml_utils as yaml_utils

def show_three_figures(y, ty1, ty2, X_train, xmin, xmax, ymin, ymax, save_path):
    """Save a 1x4 strip of scatter plots comparing four 2-D point sets.

    Panels, left to right:
      1. ``X_train`` -- titled "Training samples" (gray)
      2. ``y``       -- titled "Samples by Generator" (blue)
      3. ``ty1``     -- titled "DOT" (red)
      4. ``ty2``     -- titled "DLS" (red)

    Every panel hides its ticks and uses the same axis window
    ``[xmin, xmax] x [ymin, ymax]``.  Each point set is sliced as
    ``pts[:, :1]`` (x) and ``pts[:, 1:]`` (y).  The figure is written to
    ``save_path`` and all figures are closed afterwards; nothing is returned.
    """
    # (title, points, panel-specific scatter keywords) for each subplot.
    panels = (
        ("Training samples", X_train, {'color': 'gray'}),
        ("Samples by Generator", y, {'color': 'blue', 'label': 'y'}),
        ("DOT", ty1, {'color': 'red', 'label': 'ty'}),
        ("DLS", ty2, {'color': 'red', 'label': 'ty'}),
    )

    plt.style.use('seaborn-darkgrid')
    plt.figure(figsize=(20, 5))

    for position, (title, points, scatter_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.xticks([])
        plt.yticks([])
        plt.xlim(xmin, xmax)
        plt.ylim(ymin, ymax)
        plt.title(title, fontsize=20)
        plt.scatter(points[:, :1], points[:, 1:],
                    alpha=0.5, marker='o', **scatter_kwargs)

    plt.savefig(save_path)
    plt.close('all')

def e_grad(z, P, gen, dis, alpha, ret_e=False):
    """Differentiate the latent energy ``E(z) = -log p(z) - alpha * D(G(z))``.

    Args:
        z: latent batch to differentiate with respect to; its first
            dimension is used as the generator batch size.
        P: prior object exposing ``log_prob``; the log-probabilities are
            summed over axis 1 (keeping the axis).
        gen: generator, called as ``gen(batchsize=..., z=z)``.
        dis: discriminator, called on the generator output.
        alpha: weight of the discriminator term in the energy.
        ret_e: when true, also return the energy itself.

    Returns:
        The result of ``chainer.grad((E,), (z,))`` (callers in this file
        read element 0), or the pair ``(E, grads)`` when ``ret_e`` is true.
    """
    n_latent = z.shape[0]
    log_prior = F.sum(P.log_prob(z), 1, keepdims=True)
    critic_score = dis(gen(batchsize=n_latent, z=z))
    energy = -log_prior - alpha * critic_score
    grads = chainer.grad((energy,), (z,))
    return (energy, grads) if ret_e else grads

def langevin_dynamics(z, gen, dis, alpha, n_steps, step_lr, eps_std):
    """Run noisy gradient descent on the latent energy and record snapshots.

    Each step applies ``z <- clip(z - step_lr * dE/dz + eps, -1, 1)`` where
    ``eps`` is Gaussian noise scaled by ``eps_std`` and ``E`` is the energy
    differentiated by ``e_grad``.

    Args:
        z: initial latent batch (2-D: batch x latent dimension).
        gen: generator; ``gen.distribution`` selects the prior
            (``"normal"`` or ``'uniform'``) and ``gen.xp`` the array module.
        dis: discriminator forwarded to ``e_grad``.
        alpha: weight of the discriminator term in the energy.
        n_steps: number of update steps.
        step_lr: gradient step size.
        eps_std: standard deviation of the injected noise.

    Returns:
        List of latent batches: the state before every 5th step (steps
        0, 5, 10, ...) followed by the final state.

    Raises:
        NotImplementedError: if ``gen.distribution`` is neither
            ``"normal"`` nor ``'uniform'``.
    """
    xp = gen.xp
    batch_size, z_dim = z.shape
    if gen.distribution == "normal":
        P = chainer.distributions.Normal(xp.zeros((z_dim, ), dtype=xp.float32),
                                         xp.ones((z_dim, ), dtype=xp.float32))
    elif gen.distribution == 'uniform':
        P = chainer.distributions.Uniform(low=(xp.ones((z_dim, ), dtype=xp.float32) * -1.),
                                          high=xp.ones((z_dim, ), dtype=xp.float32))
    else:
        raise NotImplementedError(gen.distribution)

    z_sp = []
    for step in range(n_steps):
        if step % 5 == 0:
            z_sp.append(z)
        eps = eps_std * xp.random.randn(batch_size, z_dim).astype(xp.float32)
        # Only the gradient is needed here, so the energy is not requested.
        grad = e_grad(z, P, gen, dis, alpha)
        z = z - step_lr * grad[0] + eps
        # Re-wrapping the raw array starts a fresh graph each step.
        # NOTE(review): the [-1, 1] clip is applied for both priors; confirm
        # that this is intended when gen.distribution == "normal".
        z = Variable(xp.clip(z.data, -1, 1))
    z_sp.append(z)
    return z_sp


def langevin_sample(gen, dis, config, n=50000, batchsize=100):
    """Draw samples by running Langevin dynamics in latent space, batch by batch.

    The Langevin hyper-parameters ``alpha``, ``n_steps``, ``step_lr`` and
    ``eps_std`` are read from ``config.langevin``.  One batch of
    ``batchsize`` latents is processed for every start index in
    ``range(0, n, batchsize)``.

    Returns:
        A pair ``(ims, zs)`` of NumPy arrays on the host:
        ``ims`` collects, per batch, the generator output for the final
        latent state; ``zs`` stacks, per batch, all latent snapshots
        returned by ``langevin_dynamics`` (batch axis first, snapshot
        axis second).
    """
    settings = config.langevin
    alpha = settings['alpha']
    n_steps = settings['n_steps']
    step_lr = settings['step_lr']
    eps_std = settings['eps_std']

    generated_batches = []
    latent_batches = []
    for _ in range(0, n, batchsize):
        # The latent refinement itself runs with train mode switched off.
        with chainer.using_config('train', False):
            z0 = Variable(gen.sample_z(batchsize))
            trajectory = langevin_dynamics(z0, gen, dis, alpha, n_steps, step_lr, eps_std)
        final_x = gen(batchsize, trajectory[-1])
        generated_batches.append(chainer.cuda.to_cpu(final_x.data))
        snapshots = [chainer.cuda.to_cpu(state.data) for state in trajectory]
        latent_batches.append(np.stack(snapshots, axis=0))

    return np.asarray(generated_batches), np.stack(latent_batches, axis=0)

def exp_gaussian25():
    """Evaluate latent Langevin sampling (DLS) on the 25-Gaussians toy data.

    Steps:
      1. Load the pretrained WGAN-GP generator and discriminator.
      2. Estimate the baseline EMD between training data and plain
         generator samples over 100 independent draws and print the mean.
      3. For every (alpha, step_lr) pair, run ``langevin_sample``
         ``repeat_steps`` times and compute the EMD of every recorded
         latent snapshot against the training data.
      4. Select the (alpha, step_lr, snapshot) with the lowest mean EMD,
         plot it next to the training data, the raw generator samples and
         the stored DOT samples, and save the per-setting EMD tables.

    Outputs are written below ``logs/toy/gaussian``.

    Raises:
        ValueError: if a run yields more snapshots than the EMD table holds.
        RuntimeError: if no setting produced an EMD below the initial bound.
    """
    gen = Generator(n_hidden=2, noize='uni', non_linear=F.leaky_relu, final=F.identity)
    serializers.load_npz("pretrained_models/synthetic/G_25gaussians_WGAN-GP.npz", gen)
    dis = Discriminator(non_linear=F.leaky_relu, final=F.identity)
    serializers.load_npz("pretrained_models/synthetic/D_25gaussians_WGAN-GP.npz", dis)
    ret_path = 'logs/toy/gaussian'

    if gpu_device==0:
        gen.to_gpu()
        dis.to_gpu()

    n = 10000
    x_train = prepare_25gaussian_data(BATCH_SIZE=n)
    # The snapshot reshape below assumes a single batch holding all n samples.
    batchsize = n
    # safe_load: a bare yaml.load(stream) is rejected by recent PyYAML
    # releases and can build arbitrary objects on older ones.  The context
    # manager also closes the file, which the previous bare open() did not.
    with open('configs/synthetic.yml') as config_file:
        config = yaml_utils.Config(yaml.safe_load(config_file))

    hs_alpha = [1.0, ]
    hs_step_lr = [1e-4, ]

    xp = gen.xp

    # Uniform OT marginals, sized from the point sets they actually weight.
    n_train = len(x_train)
    w_train = np.ones((n_train,)) / n_train
    w_gen = np.ones((n,)) / n

    # Baseline: EMD between training data and unrefined generator samples.
    emd1_trials = np.zeros((100, ))
    for iter_step in range(100):
        z_sp = gen.sample_z(n)
        x_sp = chainer.cuda.to_cpu(gen(n, z_sp).data)
        d1 = ot.dist(x_train, x_sp)
        emd1_trials[iter_step] = ot.emd2(w_train, w_gen, d1)
    # Reported as in exp_swissroll; previously computed and then discarded.
    print('p_g emd', np.mean(emd1_trials))

    best_emd = 1e8
    best_ps = None
    best_sp = None
    cnt = 0
    repeat_steps = 10
    # Column capacity of the saved EMD table.  Kept at its historical value
    # so stored results keep their shape; unused columns stay zero.
    max_snapshots = 1000
    total_runs = len(hs_alpha) * len(hs_step_lr) * repeat_steps

    data_rec = {}
    for alpha in hs_alpha:
        for step_lr in hs_step_lr:
            config.langevin['alpha'] = alpha
            config.langevin['step_lr'] = step_lr

            emd_trials = np.zeros((repeat_steps, max_snapshots))
            # snapshot index -> (lowest EMD over trials, samples of that trial).
            # Replaces a fixed-shape sample buffer that could not hold
            # n-point sample sets.
            snapshot_best = {}
            n_snapshots = 0
            for iter_step in range(repeat_steps):
                _, all_dls_z_sp = langevin_sample(gen, dis, config, n, batchsize)
                all_dls_z_sp = all_dls_z_sp.reshape((-1, n, 2))
                n_snapshots = all_dls_z_sp.shape[0]
                if n_snapshots > max_snapshots:
                    raise ValueError(
                        'got {} snapshots but the EMD table holds only {}'.format(
                            n_snapshots, max_snapshots))

                for i in range(n_snapshots):
                    dls_z_sp = xp.asarray(all_dls_z_sp[i])
                    dls_x_sp = chainer.cuda.to_cpu(gen(n, dls_z_sp).data)

                    d2 = pairwise_distances(x_train, dls_x_sp, metric='sqeuclidean')
                    emd2 = ot.emd2(w_train, w_gen, d2)
                    emd_trials[iter_step, i] = emd2
                    # Strict '<' keeps the earliest trial on ties.
                    if i not in snapshot_best or emd2 < snapshot_best[i][0]:
                        snapshot_best[i] = (emd2, dls_x_sp)

                cnt += 1
                print('progress ', cnt / total_runs)

            emd_trial = emd_trials.mean(0)
            data_rec['{},{}'.format(alpha, step_lr)] = emd_trials
            for i in range(n_snapshots):
                emd2 = emd_trial[i]
                if emd2 < best_emd:
                    best_emd = emd2
                    best_ps = (alpha, step_lr, i, emd2)
                    best_sp = snapshot_best[i][1]

    if best_ps is None:
        raise RuntimeError('no Langevin setting produced an EMD below the initial bound')

    # Fresh unrefined generator samples for the comparison panel.
    z_sp = gen.sample_z(n)
    x_sp = chainer.cuda.to_cpu(gen(n, z_sp).data)

    save_path = os.path.join(ret_path, 'paper', 'demo-{:.4f}-{:.6f}-{}-{:.4f}.jpg'.format(*best_ps))
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    print(best_ps)
    dot_x_sp = np.load('pretrained_models/synthetic/dot_25mog.npy')
    show_three_figures(x_sp, dot_x_sp, best_sp, x_train, -2,2,-2,2, save_path=save_path)

    # The results directory differs from the figure directory above and was
    # never created before, so np.save failed on a fresh checkout.
    result_dir = os.path.join(ret_path, 'gaussian25')
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    np.save(os.path.join(result_dir, 'ret-demo'), data_rec)

def exp_swissroll():
    """Evaluate latent Langevin sampling (DLS) on the swiss-roll toy data.

    Steps:
      1. Load the pretrained WGAN-GP generator and discriminator.
      2. Estimate the baseline EMD between training data and plain
         generator samples (10000 points, 100 independent draws) and print
         the mean.
      3. For every (alpha, step_lr) pair, run ``langevin_sample``
         ``repeat_steps`` times with 5000 samples and compute the EMD of
         every recorded latent snapshot against the training data.
      4. Select the (alpha, step_lr, snapshot) with the lowest mean EMD,
         plot its samples, and save the per-setting mean EMD curves.

    Outputs are written below ``logs/toy/swissroll/swissroll``.

    Raises:
        ValueError: if a run yields more snapshots than the EMD table holds.
        RuntimeError: if no setting produced an EMD below the initial bound.
    """
    gen = Generator(n_hidden=2, noize='uni', non_linear=F.leaky_relu, final=F.identity)
    serializers.load_npz("pretrained_models/synthetic/G_swissroll_WGAN-GP.npz", gen)
    dis = Discriminator(non_linear=F.leaky_relu, final=F.identity)
    serializers.load_npz("pretrained_models/synthetic/D_swissroll_WGAN-GP.npz", dis)
    ret_path = 'logs/toy/swissroll'

    if gpu_device==0:
        gen.to_gpu()
        dis.to_gpu()

    x_train = prepare_swissroll_data(BATCH_SIZE=1000)

    # safe_load: a bare yaml.load(stream) is rejected by recent PyYAML
    # releases and can build arbitrary objects on older ones.  The context
    # manager also closes the file, which the previous bare open() did not.
    with open('configs/synthetic.yml') as config_file:
        config = yaml_utils.Config(yaml.safe_load(config_file))

    hs_alpha = [10]
    hs_step_lr = [1e-4]

    xp = gen.xp

    # Uniform OT marginals must match the point sets they weight; the
    # training set and the generated sets have different sizes here.
    n_train = len(x_train)
    w_train = np.ones((n_train,)) / n_train

    # Baseline: EMD between unrefined generator samples and training data.
    n = 10000
    w_gen = np.ones((n,)) / n
    emd1_trials = np.zeros((100, ))
    for iter_step in range(100):
        z_sp = gen.sample_z(n)
        x_sp = chainer.cuda.to_cpu(gen(n, z_sp).data)
        # Rows index generated points, columns index training points.
        d1 = ot.dist(x_sp, x_train)
        emd1_trials[iter_step] = ot.emd2(w_gen, w_train, d1)

    print('p_g emd', np.mean(emd1_trials))

    # Langevin runs use a smaller sample set.  The batch size must follow n:
    # the snapshot reshape below assumes a single batch of exactly n latents.
    n = 5000
    batchsize = n
    w_gen = np.ones((n,)) / n

    best_emd = 1e8
    best_ps = None
    best_sp = None
    cnt = 0
    repeat_steps = 10
    # Column capacity of the saved EMD curve.  Kept at its historical value
    # so stored results keep their shape; unused entries stay zero.
    max_snapshots = 1000
    total_runs = len(hs_alpha) * len(hs_step_lr) * repeat_steps

    data_rec = {}
    for alpha in hs_alpha:
        for step_lr in hs_step_lr:
            config.langevin['alpha'] = alpha
            config.langevin['step_lr'] = step_lr

            emd_trials = np.zeros((repeat_steps, max_snapshots))
            # snapshot index -> (lowest EMD over trials, samples of that trial)
            snapshot_best = {}
            n_snapshots = 0
            for iter_step in range(repeat_steps):
                _, all_dls_z_sp = langevin_sample(gen, dis, config, n, batchsize)
                all_dls_z_sp = all_dls_z_sp.reshape((-1, n, 2))
                n_snapshots = all_dls_z_sp.shape[0]
                if n_snapshots > max_snapshots:
                    raise ValueError(
                        'got {} snapshots but the EMD table holds only {}'.format(
                            n_snapshots, max_snapshots))

                for i in range(n_snapshots):
                    dls_z_sp = xp.asarray(all_dls_z_sp[i])
                    dls_x_sp = chainer.cuda.to_cpu(gen(n, dls_z_sp).data)

                    d2 = pairwise_distances(x_train, dls_x_sp, metric='sqeuclidean')
                    emd2 = ot.emd2(w_train, w_gen, d2)
                    emd_trials[iter_step, i] = emd2
                    # Strict '<' keeps the earliest trial on ties.
                    if i not in snapshot_best or emd2 < snapshot_best[i][0]:
                        snapshot_best[i] = (emd2, dls_x_sp)

                cnt += 1
                print('progress ', cnt / total_runs)

            emd_mean = emd_trials.mean(0)
            data_rec['{},{}'.format(alpha, step_lr)] = emd_mean
            for i in range(n_snapshots):
                emd2 = emd_mean[i]
                if emd2 < best_emd:
                    best_emd = emd2
                    best_ps = (alpha, step_lr, i, emd2)
                    # Samples of the best snapshot, not merely the last one computed.
                    best_sp = snapshot_best[i][1]

    if best_ps is None:
        raise RuntimeError('no Langevin setting produced an EMD below the initial bound')

    # Fresh unrefined generator samples for the comparison panel.
    z_sp = gen.sample_z(n)
    x_sp = chainer.cuda.to_cpu(gen(n, z_sp).data)

    save_path = os.path.join(ret_path, 'swissroll', '{:.4f}-{:.4f}-{}-{:.4f}.jpg'.format(*best_ps))
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    print(best_ps)
    # NOTE(review): no DOT reference samples are loaded for the swiss roll, so
    # the "DOT" panel shows the same DLS samples as the "DLS" panel.
    show_three_figures(x_sp, best_sp, best_sp, x_train, -2,2.5,-2,2.5, save_path=save_path)
    np.save(os.path.join(ret_path, 'swissroll', 'ret'), data_rec)

if __name__ == "__main__":
    # Script entry point: only the swiss-roll experiment runs by default.
    # Uncomment the next line to run the 25-Gaussians experiment as well.
    # exp_gaussian25()
    exp_swissroll()